# -*- coding: utf-8 -*-
import os
import numpy as np
import arcpy
from sklearn import preprocessing
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix,classification_report,accuracy_score,recall_score
from preprocessing.landslide_preprocessing import prepare_data, split_to_xy,plot_roc_curve

# 参数设置
def execute_prediction(data_filename='RasterUnitsV4_2003_2003.csv',
                       year=2003,
                       GeoID_name='OBJECTID',
                       label_prefix='y_isLandslide_',
                       data_version='V4',
                       clf_name='logistic',
                       process_unit='raster',
                       is_plot_roc=True,
                       test_percent=0.3,
                       results_dir=None):
    """Train a landslide classifier, evaluate it and save per-unit predictions.

    The input table is split into training and test sets. A classifier is fit
    on the training set, and an optional ROC curve is drawn from the test set.
    The fitted model is then applied to all rows (train + test). Accuracy,
    recall and the confusion matrix are reported through ``arcpy.AddMessage``.
    Predictions are written to
    ``<results_dir>/<process_unit>_<clf_name>_<data_version>_<year>.csv``
    with the columns ``GeoID,Predicted_Y,Prob_<year>``.

    Args:
        data_filename: Input table, passed to ``prepare_data``.
        year: Year appended to the label prefix; also used in output names.
        GeoID_name: Column holding the unit identifier copied to the output.
        label_prefix: Prefix of the label column name.
        data_version: Version tag used in the output file names.
        clf_name: ``'svm'`` selects a linear-kernel SVC; any other value
            selects logistic regression. Both use ``class_weight='balanced'``.
        process_unit: ``'raster'`` (label column ``<prefix><year>``) or
            ``'slope'`` (label column ``<prefix><year>_MAJORITY``).
        is_plot_roc: Whether to draw the ROC curve and compute the AUC.
        test_percent: Test-split value passed to ``prepare_data``
            (presumably the fraction of rows held out for testing).
        results_dir: Output directory. Defaults to a ``results`` folder next
            to this source file, and is created if it does not exist.

    Returns:
        Tuple ``(True, accuracy, auc, recall)``. ``auc`` is ``None`` when
        ``is_plot_roc`` is false.

    Raises:
        ValueError: If ``process_unit`` is neither ``'raster'`` nor ``'slope'``.
    """
    # Local import: DataFrame.append was removed in pandas 2.0, so the
    # train/test frames are merged with pd.concat instead.
    import pandas as pd

    # Resolve the label (target) column name first, so an invalid
    # process_unit fails fast instead of with a NameError later on.
    if process_unit == "raster":
        label_name = "{0}{1}".format(label_prefix, year)
    elif process_unit == "slope":
        label_name = "{0}{1}{2}".format(label_prefix, year, "_MAJORITY")
    else:
        raise ValueError(
            "Unsupported process_unit: {0!r} (expected 'raster' or 'slope')"
            .format(process_unit))

    # By default, save results into the "results" folder next to this file.
    if results_dir is None:
        current_file_dir = os.path.dirname(os.path.realpath(__file__))
        results_dir = os.path.join(current_file_dir, "results")
    arcpy.AddMessage("Results saved directory: 【{}】".format(results_dir))
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Step 1: split the data into a training set and a test set.
    # NOTE(review): test_percent is passed twice here; confirm this matches
    # prepare_data's signature (the third argument may be meant to differ).
    train_df, test_df = prepare_data(data_filename, test_percent, test_percent,
                                     label_name)

    # Step 2: split each set into features X and target labels y.
    train_X, train_y = split_to_xy(train_df,
                                   class_col_name=label_name,
                                   normalized=True)
    test_X, test_y = split_to_xy(test_df,
                                 class_col_name=label_name,
                                 normalized=True)

    # Step 3: binarize the labels; with two classes, label_binarize returns
    # a single column, which is flattened with ravel() before fitting.
    train_y = preprocessing.label_binarize(train_y, classes=[0, 1])
    test_y = preprocessing.label_binarize(test_y, classes=[0, 1])

    # Step 4: build and fit the classifier. class_weight='balanced' weights
    # each class inversely to its frequency in the training labels.
    if clf_name == 'svm':
        clf = SVC(kernel='linear', probability=True, class_weight="balanced")
    else:
        clf = LogisticRegression(class_weight='balanced')
    clf.fit(train_X, train_y.ravel())

    # Step 5: draw the ROC curve from the test-set decision scores.
    auc_val = None  # stays None when no ROC curve is requested
    if is_plot_roc:
        y_score = clf.decision_function(test_X)
        auc_val = plot_roc_curve(year, test_y, y_score, pos_label=1,
                                 classifier_name=clf_name,
                                 process_units=process_unit,
                                 results_dir=results_dir,
                                 data_version=data_version)
        arcpy.AddMessage("AUC:{0:.3f}".format(auc_val))

    # Step 6: predict every unit using all data (training + test rows).
    # NOTE(review): the metrics below therefore include training rows, and
    # split_to_xy normalizes the combined frame; confirm this is intended.
    all_df = pd.concat([train_df, test_df])
    all_X, all_y = split_to_xy(all_df, class_col_name=label_name,
                               normalized=True)
    predicted = clf.predict(all_X)

    # Step 7: accuracy assessment.
    accuracy = accuracy_score(all_y, predicted)
    recall_val = recall_score(all_y, predicted)
    conf_mat = confusion_matrix(all_y, predicted, labels=[0, 1])
    arcpy.AddMessage("\t accuracy: {0:.3f}".format(accuracy))
    arcpy.AddMessage("\t recall:{0:.3f}".format(recall_val))
    arcpy.AddMessage("\t confusion_matrix: ")
    arcpy.AddMessage(str(conf_mat))

    # Step 8: one row per unit: (GeoID, predicted class, P(class == 1)).
    predicted_prob = clf.predict_proba(all_X)
    geo_ids = all_df[GeoID_name].values
    results = np.column_stack((geo_ids, predicted, predicted_prob[:, 1]))

    # Step 9: save the results as CSV. comments="" stops numpy from prefixing
    # the header line with "# ", so the first line is a plain CSV header.
    saved_filename = os.path.join(
        results_dir,
        "{0}_{1}_{2}_{3}.csv".format(process_unit, clf_name, data_version,
                                     year))
    header_string = "GeoID,Predicted_Y,Prob_{0}".format(year)
    np.savetxt(saved_filename, results,
               header=header_string, fmt="%d,%d,%0.5f",
               delimiter=",", comments="")
    return True, accuracy, auc_val, recall_val

if __name__ == "__main__":
    # ArcGIS script-tool entry point: read the tool parameters by position
    # and run the prediction with them.
    input_csv = arcpy.GetParameterAsText(0)
    year = int(arcpy.GetParameterAsText(1))
    ID_name = arcpy.GetParameterAsText(2)
    label_prefix = arcpy.GetParameterAsText(3)
    data_ver = arcpy.GetParameterAsText(4)
    clf = arcpy.GetParameterAsText(5)
    process_unit = arcpy.GetParameterAsText(6)
    is_plot_roc = arcpy.GetParameter(7)
    test_per = float(arcpy.GetParameterAsText(8))

    arcpy.AddMessage(str(test_per))
    # The keyword is ``process_unit`` (singular) to match execute_prediction's
    # signature; ``process_units`` would raise TypeError.
    is_success, accuracy, auc_val, recall_val = execute_prediction(
        data_filename=input_csv,
        year=year,
        GeoID_name=ID_name,
        label_prefix=label_prefix,
        data_version=data_ver,
        clf_name=clf,
        process_unit=process_unit,
        is_plot_roc=is_plot_roc,
        test_percent=test_per)

    # Example of running outside the ArcGIS tool dialog:
    # arcpy.env.workspace = "G:/DataForDoctorPaper/博士论文数据.gdb"
    # execute_prediction(
    #     r"C:\Users\Luoge\PycharmProjects\ArcPy_FactorsExtractor\results\RasterUnitsV4_2010.csv", 2010, "OBJECTID",
    #     "y_isLandslide_", "V4", "logistic", "raster", True, 0.3)

    # execute_prediction(data_filename='RasterUnitsV4_2003_2003.csv',
    #                    year=2003,
    #                    GeoID_name='OBJECTID',
    #                    label_prefix='y_isLandslide_',
    #                    data_version='V4',
    #                    clf_name='logistic',
    #                    process_unit='raster',
    #                    is_plot_roc=True,
    #                    test_percent=0.3)